iT邦幫忙

2025 iThome 鐵人賽

DAY 21
0
AI & Data

零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!系列 第 21

【Day 21】從 Wx+b 到能寫詩的模型GPT-2 的煉成

  • 分享至 

  • xImage
  •  

前言

今天我們來聊聊 GPT 模型的架構,特別是現在很常見、也很實用的「Decoder-only」設計。這類模型其實已經在各種任務上展現出超強的能力,無論是生成長篇文章、聊天對話,甚至是寫程式,都有非常不錯的表現。

所以今天我們就從 GPT-2 的基本設計開始,一步步帶大家拆解這種架構到底怎麼組成、有哪些地方容易踩雷,又有哪些訓練技巧是真的有幫助的。我們不會去講太多花俏的設計,而是回到最小可行架構希望讓大家可以從底層真正搞懂這個模型的原理,也能在實作的時候少走一點冤枉路。

GPT 是什麼?從 GPT-2 講起

GPT-2,全名是 Generative Pre-trained Transformer 2,它在自然語言處理(NLP)這個領域裡可以說是一個超重要的里程碑。

它雖然跟 Google 的 BERT 一樣,都是基於 Transformer 架構打造出來的模型,但它們的設計邏輯其實大不相同。BERT 的重點是「理解語意」,所以它會從前後兩邊同時讀取文字,透過所謂的「雙向編碼」來預測句子中被遮蔽的詞語。簡單說,它是在考你對上下文的理解力。

但 GPT-2 玩的是另一種套路。它的策略是「自回歸生成」,也就是從左到右一個詞一個詞慢慢生出來。這樣的方式,就像人類在寫東西時,一邊想、一邊打字的邏輯流動。因為它是按順序產出語句,所以在生成像小說、聊天對話、甚至程式碼時,它的自然度跟創造力都表現得非常強。

GPT 的目標不是理解文字,而是要創作,而作這件事本來就不是先知道全部再倒推,而是像人類一樣,一步步寫下去。

GPT 的預訓練任務

GPT-2 訓練的核心任務叫做 自回歸語言建模(Causal Language Modeling),意思是它要學會預測下一個詞會是什麼,舉例來說,給它一串文字 [x₁, x₂, …, xₙ],它的工作是學會在每個時間點預測下一個 token 的機率。這用數學式子表示如下:

P(x₁, x₂, …, xₙ) = ∏ P(xᵢ | x₁, x₂, …, xᵢ₋₁)

這句話翻成白話就是每個詞的出現只能根據它前面那些詞來判斷,不能偷看後面還沒出現的內容。這種規則也就是自回歸的本質。而這樣才貼近人類書寫時的真實狀況。當我們在打字時,是不知道未來幾個字會怎樣寫的,我們只能根據現在的語境去決定下一步。

GPT 的模型結構

雖然 GPT 和 BERT 都是建立在 Transformer 這個架構之上,但其實它們對這個原始設計並沒有大刀闊斧地改造,基本骨架幾乎一模一樣。大多數的變化,其實只是一些細節上的調整。以 GPT 為例最主要的幾個修改包括:使用了可學習的位置編碼(learnable position embeddings),還有把 LayerNorm 的位置做了調整。

我們先來看 Transformer 原始的設計。在 2017 年的那篇經典論文中,每一層的處理流程大概是這樣的:

x = x + Sublayer(x)  
x = LayerNorm(x)

這叫做 Post-LN 架構,意思是模組處理完後,再加上原來的輸入最後做 Layer Normalization。這樣做的好處是訓練初期穩定,不容易一開始就亂掉。但隨著模型越來越深,比如 GPT-2 有幾十層那種深度,就發現這種設計會在訓練後期出現 梯度消失 的問題模型變得學不動了。

所以 GPT-2 改用了一種叫做 Pre-LN 架構 的設計,它把 LayerNorm 移到一開始,流程變成這樣:

y = x + Sublayer(LayerNorm(x))

這個改動讓模型在非常深的情況下還能保持穩定,也比較容易訓練得好。這也是為什麼 GPT-2 能做出像 1.5 億個參數、甚至超過 40 層深度的大型模型,還能有效運作。

你可能會想為什麼 LayerNorm 的位置會影響這麼大?因為 LayerNorm 是在調整訊號的穩定性。如果等模組跑完才正規化,深層模型可能會累積太多訊號雜訊,最後訓練失效。反過來,一開始就做正規化,會讓整個訊號流程更穩。

Wx + b 就能造出 GPT-2?

如果你對 Transformer 架構有點概念,那 GPT-2 的設計應該不會太陌生。這邊就不從頭細講整個架構了,我們挑幾個比較核心的部分來聊一下。先來看看 HuggingFace 的 transformers 套件裡 GPT-2 的模型結構,大致上是這樣:

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

1. Attention

在 GPT 的 Attention 設計裡,其實跟原本 Transformer Decoder 的做法是一樣的,使用的是 causal attention。這種設計的關鍵在於模型在預測下一個詞的時候,只能看到它前面的詞,不能偷看後面的內容,這樣才能符合語言生成的因果順序。而在 GPT 的實作裡,Attention 中的 Q、K、V 是透過一個叫 c_attn 的模組來計算的,輸出結果則是透過 c_proj 來處理。這兩個部分,其實本質上都是用一個叫 Conv1D 的模組來實作的。

不過這裡的 Conv1D 有點容易讓人誤會。它名字裡雖然有Conv,但其實跟我們在 CNN 裡學到的那種一維卷積完全不一樣。這裡的 Conv1D 其實就是一個線性層,本質上就是做一個矩陣乘法加上偏置,所以它不是真的做卷積,而是把輸入的向量轉成我們需要的維度。

class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # x: [..., nx]
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x

而GPT-2 把 QKV 和 O 分開處理(用 c_attn 處理 QKV,用 c_proj 處理 O)其實還有個很實用的好處,就是做 hook 或分析模型的時候方便很多。如果你只是想抓出來看看模型在跑的時候產生的 Q、K、V 值,那只要 hook 一下 attn.c_attn 這個模組就好,不但寫起來簡單、程式碼也比較乾淨好維護。反之如果你只關心最後注意力的輸出(也就是 O),那就可以直接 hook attn.c_proj

class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        nx = config.n_embd
        n_head = config.n_head
        if nx % n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = n_head
        self.head_dim = nx // n_head
        self.scale_attn = 1.0 / math.sqrt(self.head_dim)

        # c_attn projects to q, k, v concatenated
        self.c_attn = Conv1D(3 * nx, nx)
        self.c_proj = Conv1D(nx, nx)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # Register a causal mask buffer up to max positions
        max_pos = config.n_positions
        mask = torch.tril(torch.ones((max_pos, max_pos), dtype=torch.bool))
        self.register_buffer("causal_mask", mask[None, None, :, :], persistent=False)  # [1,1,T,T]

    def _split_heads(self, x):
        # x: [B, T, n_embd] -> [B, n_head, T, head_dim]
        B, T, C = x.size()
        x = x.view(B, T, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        return x

    def _merge_heads(self, x):
        # x: [B, n_head, T, head_dim] -> [B, T, n_embd]
        x = x.permute(0, 2, 1, 3).contiguous()
        B, T, _, _ = x.size()
        return x.view(B, T, self.n_head * self.head_dim)

    def forward(self, x, attention_mask=None):
        B, T, _ = x.size()

        qkv = self.c_attn(x)  # [B, T, 3*n_embd]
        q, k, v = qkv.split(qkv.size(-1) // 3, dim=2)

        q = self._split_heads(q)  # [B, h, T, hd]
        k = self._split_heads(k)  # [B, h, T, hd]
        v = self._split_heads(v)  # [B, h, T, hd]

        attn_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale_attn  # [B,h,T,T]

        # Causal mask
        attn_scores = attn_scores.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float("-inf"))

        # Additive attention mask [B,1,1,T], if provided
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask  # broadcast on last dim

        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        context = torch.matmul(attn_probs, v)  # [B,h,T,hd]
        context = self._merge_heads(context)   # [B,T,n_embd]
        out = self.c_proj(context)
        out = self.resid_dropout(out)
        return out

2. FFN

而 FFN 在 GPT-2 裡是透過 GPT2MLP 這個類別來實作的。不過要特別注意GPT-2 採用的是 Pre-LN 架構,也就是說,在進入 FFN(以及 Self-Attention)之前,會先做 LayerNorm,而不是像某些其他模型那樣把 LayerNorm 放在最後。

class GPT2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        nx = config.n_embd
        # HF uses intermediate size = 4 * n_embd by default (config.n_inner may override)
        n_inner = getattr(config, "n_inner", None) or 4 * nx
        self.c_fc = Conv1D(n_inner, nx)
        self.c_proj = Conv1D(nx, n_inner)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

3. GPT2Block

GPT2Block 所採用的其實正是 Pre-LN 架構,其實現方式並不複雜,基本延續了我們先前所構建的處理流程,只是在每個子模組執行前引入 LayerNorm 而已。

class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        eps = getattr(config, "layer_norm_epsilon", 1e-5)  # HF key
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=eps)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=eps)
        self.mlp = GPT2MLP(config)

    def forward(self, x, attention_mask=None):
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x

4. GPT2Model

在最終階段我們構建了 GPT2Model 架構,開頭部分的 wte(word token embedding)負責將輸入的 token 映射到向量空間。這裡的 50257 是 GPT-2 的詞彙表大小,表示模型能識別的 token 總數,而 768 則代表每個 token 的向量維度。

接下來是 wpe(position embedding),它負責加入位置資訊。與原始 Transformer 採用固定的正弦位置編碼不同,GPT-2 選擇了 可訓練的嵌入向量,也就是說每個位置都有一個參數化的向量,能夠在訓練過程中學習序列中位置的語義特徵。預設最大長度為 1024,表示這個模型最多能處理 1024 個 token 的輸入長度。這兩者加總後會經過 dropout 層,再傳入一連串的 GPT2Block

class GPT2Model(nn.Module):
    """
    Matches HF GPT2Model module/param names:
      - wte, wpe, h.{i}.attn.{c_attn,c_proj}, h.{i}.ln_1, h.{i}.mlp.{c_fc,c_proj}, ln_f
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        eps = getattr(config, "layer_norm_epsilon", 1e-5)
        self.ln_f = nn.LayerNorm(config.n_embd, eps=eps)
        
    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, return_dict=False):
        B, T = input_ids.size()
        if T > self.config.n_positions:
            raise ValueError(f"Sequence length {T} exceeds n_positions {self.config.n_positions}")

        # Positions
        pos = torch.arange(T, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(B, T)

        # Embeddings
        x = self.wte(input_ids) + self.wpe(pos)
        x = self.drop(x)

        # Attention mask -> additive [B,1,1,T]
        ext_mask = _make_extended_attn_mask(attention_mask, x.dtype) if attention_mask is not None else None

        all_hidden_states = [] if output_hidden_states else None
        for block in self.h:
            if output_hidden_states:
                all_hidden_states.append(x)
            x = block(x, attention_mask=ext_mask)
        x = self.ln_f(x)
        if output_hidden_states:
            all_hidden_states.append(x)

        if return_dict:
            return {"last_hidden_state": x, "hidden_states": all_hidden_states}
        return (x, all_hidden_states)

到目前為止我們已經完成了模型的主體結構,不過很明顯還少了一個關鍵部分,輸出層現在模型僅產生了 hidden states,也就是最後一層 Decoder 的隱表示。但這些向量本身還不能直接對應到語言輸出。為了讓模型能夠預測下一個詞或 token,還需要一個額外的線性層,將 hidden states 投影回詞彙表的大小,從而生成 logits 分佈。

換句話說我們還少了一個 詞彙投影層(language modeling head),它負責將隱藏狀態轉換為每個 token 的機率分佈,這才是實際生成文字的關鍵步驟。

5.GPT2LMHeadModel

有個滿關鍵的細節是,lm_head 的權重其實是跟 wte(也就是輸入的詞嵌入層)共用的。這種做法叫做 weight tying,簡單來說就是把輸入跟輸出用同一組權重。這樣不只可以大幅減少模型的參數量,也能讓學習過程更穩定。而 lm_head 這一層,正是模型用來產生最終文字的那個head

class GPT2LMHeadModel(nn.Module):
    """
    Matches HF GPT2LMHeadModel heads and weight tying:
      - transformer (GPT2Model)
      - lm_head.weight tied to transformer.wte.weight
    """
    def __init__(self, config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # tie weights
        self.lm_head.weight = self.transformer.wte.weight

    # HF API helpers
    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

接著在 forward 函數裡,模型會先算出最後一層的 hidden states,然後通過 lm_head 把它轉換成 logits,也就是每個 token 對應所有詞彙的預測分數。通常我們在訓練的時候,會對 logits 和 labels 做個「右移」對齊,這樣模型才能學會預測「下一個」字。

    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False, return_dict=False):
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = outputs["last_hidden_state"]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if return_dict:
            return {"loss": loss, "logits": logits, "hidden_states": outputs["hidden_states"]}
        return (loss, logits, outputs["hidden_states"])

最後如果我們要讓模型產生文字,最終輸出的 logits 是個三維的張量,形狀是 [batch_size, sequence_length, vocab_size],每個位置都表示那個時間點上,每個詞出現的機率。然後就可以用像是取最大值或是抽樣的方法,從這些機率裡選出最有可能的下一個字,完成一整段的生成。

下集預告

今天我們從 GPT-2 的基礎設計一路拆解到整個模型的實作細節,應該可以感受到Decoder-only架構雖然看起來簡單,背後其實藏了不少設計巧思。那明天我們要來換個口味,實際動手做一個簡單但實用的應用場景:語言翻譯任務。這個任務看似老派但正因為夠直觀,也能夠比對與seq2seq的差異,那麼我們明天再見!


上一篇
【Day 20】Decoder 為何會胡說八道 Transformer 的生成機制與幻覺真相
下一篇
【Day 22】不靠 Encoder?用 GPT-2 試試翻譯的可能性
系列文
零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!24
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言